Track generated functions for torch compile #1094
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##             main    #1094      +/-   ##
==========================================
- Coverage   82.12%   82.12%   -0.01%
==========================================
  Files         183      183
  Lines       47986    48016      +30
  Branches     8644     8648       +4
==========================================
+ Hits        39409    39433      +24
- Misses       6411     6417       +6
  Partials     2166     2166
```
Force-pushed 05cbafc to d03bfc7
Force-pushed d03bfc7 to 9bbc90a
```python
def conversion_func_register(*args, **kwargs):
    functor = pytorch_funcify(*args, **kwargs)
    module = pytensor.link.utils
    setattr(module, kwargs["unique_name"](functor), functor)
```
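The registration pattern above can be illustrated in isolation: the point is that attaching a generated function to a long-lived module makes it resolvable by qualified name, which is what torch.compile-generated code needs. A minimal sketch, using a stand-in module and hypothetical names (`register_generated_fn`, `make_adder` are illustrative, not PR code):

```python
import types

# Stand-in for pytensor.link.utils; names here are illustrative, not PR code.
host_module = types.ModuleType("host_module")

def register_generated_fn(module, name, fn):
    # Attach a generated function to a long-lived module so code that refers
    # to it by qualified name (as torch.compile-generated code does) resolves.
    setattr(module, name, fn)
    return getattr(module, name)

def make_adder(k):
    # Stand-in for a closure produced by pytorch_funcify.
    def add_k(x):
        return x + k
    return add_k

fn = register_generated_fn(host_module, "generated_add_3", make_adder(3))
print(host_module.generated_add_3(4))  # 7
```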
Do we need a weakref perhaps? So memory does get freed if nothing else is using these functions?
Ah that's a good idea. Let me try doing weakref.
I tried doing weakref (à la `weakref.ref`) in a few spots and couldn't get it to work. Maybe more importantly, it's pretty intrusive: the generated code now has to know it's a weakref and call it differently, all backends need weakrefs, etc. I'm not sure this is a good path.
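For context on why this is intrusive: a `weakref.ref` must be dereferenced by the caller before use, and it dies as soon as the last strong reference does, which is exactly the failure mode below. A standalone illustration (the `Functor` class is a hypothetical stand-in, not PR code):

```python
import weakref

class Functor:
    # Stand-in for a generated funcify closure (illustrative, not PR code).
    def __call__(self, x):
        return x * 2

strong = Functor()
ref = weakref.ref(strong)

# Callers must know it's a weakref: dereference with ref(), then call.
assert ref()(5) == 10

del strong  # last strong reference gone; CPython frees immediately
assert ref() is None  # "weakly-referenced object no longer exists"
```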
The problem with this is that PyTensor becomes a giant memory leak if you compile enough pytensor functions?
Can we get input from the torch devs now that we've narrowed down the problem?
I'll definitely make an ask to the torch team to figure it out, and post here. Probably before we merge this I imagine?
To the point about a giant memory leak, I'm honestly not sure that'll be the case / I'm not particularly worried. These methods work as-is if you disable torch compile, so that tells me the closures are already kept alive somewhere. When profiling the script above with memray, I see allocations / memory use within a normal wiggle room (±5 allocs, ~50 kB). Here are the zips.
prof.zip
This is the issue we end up seeing wrt weakref:

```
torch._dynamo.exc.InternalTorchDynamoError: weakly-referenced object no longer exists

from user code:
  File "/var/folders/98/g1t2_d2x4w94vqfv06xhyz6c0000gp/T/tmpnk75e974", line 3, in pytorch_funcified_fgraph
    tensor_variable_2, tensor_variable_3 = pytorch_funcified_fgraph(y, z)
```
Is there a way to make the FunctionGraph that torch keeps around be the one referencing those closures? `pytensor.link.utils` will never go away unless someone does `del pytensor`.
Maybe we can add a `__del__` method in the torch linker, and have that clean up these explicit references?
Additionally, yeah, I'll see if I can cheese something where we have the global fgraph hold a reference somewhere...
I tried to add a sort of wrapper execution function that allows us to clean up after ourselves. I do make the assumption that each `pytensor.function` def will have its own `PytorchLinker`; lmk if that is not correct. This successfully cleaned up the references after the new function goes out of scope.
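The wrapper idea can be sketched roughly like this: register the generated functors on the module around each call, then remove them again, so nothing lingers after the function goes out of scope. This is a hypothetical sketch (`WrappedFn` and the stand-in `host` module are illustrative names, not the PR's actual code):

```python
import types

host = types.ModuleType("host")  # stand-in for pytensor.link.utils

class WrappedFn:
    """Hypothetical sketch of the wrapper idea: register gen_functors on the
    module around each call, then clean up, so nothing lingers afterwards."""
    def __init__(self, fn, gen_functors):
        self.fn = fn
        self.gen_functors = list(gen_functors)

    def __call__(self, *args, **kwargs):
        for name, functor in self.gen_functors:
            setattr(host, name, functor)
        try:
            return self.fn(*args, **kwargs)
        finally:
            for name, _ in self.gen_functors:
                delattr(host, name)

wrapped = WrappedFn(
    lambda x: host._generated_double(x) + 1,
    [("_generated_double", lambda v: 2 * v)],
)
print(wrapped(10))  # 21
print(hasattr(host, "_generated_double"))  # False: cleaned up after the call
```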
Yes, the function will have its own `PyTorchLinker`.
```python
self.gen_functors = copy.copy(gen_functors)

def __call__(self, *args, **kwargs):
    import pytensor.link.utils
```
Note, there is no way this is threadsafe.
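One way the registration could be made safe against concurrent calls, sketched under assumptions (a coarse global lock; none of this is in the PR, and `call_with_functors` is a hypothetical name):

```python
import threading
import types

host_mod = types.ModuleType("host_mod")  # stand-in for pytensor.link.utils
_registry_lock = threading.Lock()        # hypothetical guard; not in the PR

def call_with_functors(fn, functors, *args):
    # Serialize attach -> call -> detach so concurrent calls cannot delete
    # each other's module attributes mid-call (coarse-grained sketch only).
    with _registry_lock:
        for name, functor in functors:
            setattr(host_mod, name, functor)
        try:
            return fn(*args)
        finally:
            for name, _ in functors:
                delattr(host_mod, name)

res = call_with_functors(lambda x: host_mod._inc(x), [("_inc", lambda v: v + 1)], 41)
print(res)  # 42
```

A coarse lock serializes all calls, which costs parallelism; per-function unique names would be a lighter-weight alternative.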
Looks better. Can you add a test showing the extra attributes on the module being discarded when you `del` the function?
Also, can we store them with a leading `_`, so they don't clutter things up during debugging?
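The requested test could look roughly like this: register an attribute, delete the owning object, and assert the attribute is gone. This is a hypothetical sketch against a stand-in module (`RegisteredFn` and `utils_mod` are illustrative names, not the PR's code):

```python
import gc
import types

utils_mod = types.ModuleType("utils_mod")  # stand-in for pytensor.link.utils

class RegisteredFn:
    """Hypothetical sketch of the requested test: attributes added to the
    module are discarded again once the owning function is deleted."""
    def __init__(self, name, functor):
        self.name = name
        setattr(utils_mod, name, functor)

    def __del__(self):
        if hasattr(utils_mod, self.name):
            delattr(utils_mod, self.name)

fn = RegisteredFn("_tmp_fn", lambda: "ok")
assert utils_mod._tmp_fn() == "ok"

del fn
gc.collect()  # CPython frees immediately; collect() helps on other runtimes
assert not hasattr(utils_mod, "_tmp_fn")
```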
pytensor/link/pytorch/linker.py (outdated)

```python
@@ -34,3 +71,6 @@ def create_thunk_inputs(self, storage_map):
            thunk_inputs.append(sinput)

        return thunk_inputs

    def record_fn(self, name, fn):
```
Is this used?
I actually removed this and moved the "conversion" func into the linker, so now the linker is doing all the tracking and stuff. Seemed more reasonable.
And a more readable PR title?
Force-pushed d7b570c to 0c90a51
Force-pushed 7631a68 to 07e6113
Force-pushed 0d24cdc to aa6aac2
Okay @ricardoV94 if this still looks good, it should be ready to merge.
Description
The torch compiler was having an issue running pytensor graphs with subgraphs. The way pytorch's module resolution worked was causing pytorch guards to fail during creation. Torch requires the guards to be created, and to be correct at runtime, so not being able to make one meant things couldn't get compiled. This tries to band-aid that by putting the functions we generate onto a module, essentially giving them an explicit lifecycle. This means we don't need graph breaks, which means faster compile times and better compiler results.
Benchmark attachments (not captured here): "for this code", "with change", "without change".
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1094.org.readthedocs.build/en/1094/